{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Approximate bayesian inference using EP\n",
    "\n",
    "This notebooks demonstrates how to use the EP algorithm for pairwise data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import choix\n",
    "import numpy as np\n",
    "\n",
    "np.set_printoptions(precision=3, suppress=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "n_items = 8\n",
    "data = [\n",
    "    (7, 3), (2, 0), (5, 2), (4, 2), (2, 1),\n",
    "    (4, 5), (6, 3), (5, 4), (7, 0), (2, 3),\n",
    "    (4, 0), (0, 4), (6, 5), (3, 2), (3, 4),\n",
    "    (3, 4), (5, 2), (7, 3), (7, 6), (6, 5),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logit observation model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean:\n",
      "[-1.638 -2.752 -1.011 -0.016 -0.94  -0.047  2.324  4.081]\n",
      "\n",
      "covariance:\n",
      "[[ 2.767  0.659  1.275  1.127  1.492  1.13   0.814  0.738]\n",
      " [ 0.659  5.367  1.033  0.689  0.681  0.695  0.485  0.391]\n",
      " [ 1.275  1.033  2.     1.333  1.318  1.345  0.939  0.757]\n",
      " [ 1.127  0.689  1.333  2.204  1.358  1.158  1.098  1.033]\n",
      " [ 1.492  0.681  1.318  1.358  2.043  1.368  0.959  0.782]\n",
      " [ 1.13   0.695  1.345  1.158  1.368  2.229  1.256  0.818]\n",
      " [ 0.814  0.485  0.939  1.098  0.959  1.256  3.022  1.426]\n",
      " [ 0.738  0.391  0.757  1.033  0.782  0.818  1.426  4.055]]\n",
      "\n",
      "total variance: 23.687252\n"
     ]
    }
   ],
   "source": [
    "mean, cov = choix.ep_pairwise(n_items, data, 0.1, model=\"logit\")\n",
    "print(\"mean:\")\n",
    "print(mean)\n",
    "print(\"\\ncovariance:\")\n",
    "print(cov)\n",
    "print(\"\\ntotal variance: {:3f}\".format(np.trace(cov)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Probit observation model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean:\n",
      "[-1.198 -2.927 -0.797 -0.126 -0.742 -0.15   2.019  3.922]\n",
      "\n",
      "covariance:\n",
      "[[ 1.896  0.835  1.355  1.299  1.455  1.306  1.014  0.84 ]\n",
      " [ 0.835  4.451  0.99   0.857  0.843  0.839  0.654  0.53 ]\n",
      " [ 1.355  0.99   1.607  1.392  1.37   1.363  1.062  0.861]\n",
      " [ 1.299  0.857  1.392  1.708  1.373  1.294  1.124  0.953]\n",
      " [ 1.455  0.843  1.37   1.373  1.621  1.398  1.074  0.867]\n",
      " [ 1.306  0.839  1.363  1.294  1.398  1.705  1.192  0.903]\n",
      " [ 1.014  0.654  1.062  1.124  1.074  1.192  2.417  1.462]\n",
      " [ 0.84   0.53   0.861  0.953  0.867  0.903  1.462  3.584]]\n",
      "\n",
      "total variance: 18.990294\n"
     ]
    }
   ],
   "source": [
    "mean, cov = choix.ep_pairwise(n_items, data, 0.1, model=\"probit\")\n",
    "print(\"mean:\")\n",
    "print(mean)\n",
    "print(\"\\ncovariance:\")\n",
    "print(cov)\n",
    "print(\"\\ntotal variance: {:3f}\".format(np.trace(cov)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
